import torch
import torch.nn as nn
import numpy as np
from rlcore.distributions import Categorical
import torch.nn.functional as F
import math
from layers import GraphAttentionLayer
from tensorboardX import SummaryWriter


def weights_init(m):
    """Orthogonally initialise Conv*/Linear* layer weights and zero their biases.

    Intended for use with ``nn.Module.apply``; modules whose class name
    contains neither 'Conv' nor 'Linear' are left untouched.
    """
    layer_type = type(m).__name__
    if 'Conv' in layer_type or 'Linear' in layer_type:
        nn.init.orthogonal_(m.weight.data)
        if m.bias is not None:
            m.bias.data.fill_(0)


class MPNN(nn.Module):
    """Graph-attention message-passing actor-critic for multi-agent RL.

    Per-agent observations are encoded, passed through ``K`` rounds of
    graph-attention message passing over the agent communication graph
    (optionally preceded by an attention step over landmark entities),
    and the resulting per-agent embeddings feed a categorical policy
    head and an ensemble (attention-based) value head.
    """

    def __init__(self, action_space, num_agents, num_entities, input_size=16, hidden_dim=128, embed_dim=None,
                 pos_index=2, norm_in=False, nonlin=nn.ReLU, n_heads=1, mask_dist=None, entity_mp=False):
        super().__init__()

        self.h_dim = hidden_dim
        self.nonlin = nonlin
        self.num_agents = num_agents          # number of agents N
        self.num_entities = num_entities      # number of landmark entities
        self.K = 3                            # message-passing rounds
        self.embed_dim = self.h_dim if embed_dim is None else embed_dim
        self.n_heads = n_heads
        # None = full communication, >0 = distance cutoff, -10 = random dropout
        self.mask_dist = mask_dist
        self.input_size = input_size
        self.entity_mp = entity_mp
        # offset of the (x, y) position from the start of the observation vector
        self.pos_index = pos_index

        self.encoder = nn.Sequential(nn.Linear(self.input_size, self.h_dim),
                                     self.nonlin(inplace=True))

        # multi-head attention message function (kept registered for
        # checkpoint compatibility; _fwd currently uses the GAT layers)
        self.messages = MultiHeadAttention(n_heads=self.n_heads, input_dim=self.h_dim, embed_dim=self.embed_dim)

        # K graph-attention layers, registered individually so their
        # parameters appear in the state dict as 'attention_{i}.*'
        self.attentions = [GraphAttentionLayer(self.h_dim, self.h_dim, dropout=0.6, alpha=0.2, concat=True)
                           for _ in range(self.K)]
        for i, attention in enumerate(self.attentions):
            self.add_module('attention_{}'.format(i), attention)
        self.attention_value = Attention_value(hidden_dim=128, embed_dim=None, nonlin=nn.ReLU)

        self.update = nn.Sequential(nn.Linear(self.h_dim + self.embed_dim, self.h_dim),
                                    self.nonlin(inplace=True))

        # plain MLP value head; superseded by self.attention_value but kept
        # registered so existing checkpoints still load
        self.value_head = nn.Sequential(nn.Linear(self.h_dim, self.h_dim),
                                        self.nonlin(inplace=True),
                                        nn.Linear(self.h_dim, 1))

        self.policy_head = nn.Sequential(nn.Linear(self.h_dim, self.h_dim),
                                         self.nonlin(inplace=True))

        if self.entity_mp:
            # encoder / attention / update for landmark-entity message passing
            self.entity_encoder = nn.Sequential(nn.Linear(2, self.h_dim),
                                                self.nonlin(inplace=True))

            self.entity_messages = MultiHeadAttention(n_heads=1, input_dim=self.h_dim, embed_dim=self.embed_dim)

            self.entity_update = nn.Sequential(nn.Linear(self.h_dim + self.embed_dim, self.h_dim),
                                               self.nonlin(inplace=True))

        num_actions = action_space.n
        self.dist = Categorical(self.h_dim, num_actions)

        self.is_recurrent = False

        if norm_in:
            self.in_fn = nn.BatchNorm1d(self.input_size)
            self.in_fn.weight.data.fill_(1)
            self.in_fn.bias.data.fill_(0)
        else:
            self.in_fn = lambda x: x
        # NOTE: runs after the BatchNorm setup above but only touches
        # Conv/Linear layers, so the in_fn initialisation survives
        self.apply(weights_init)

        # last attention matrix, exposed for visualisation/debugging
        self.attn_mat = np.ones((num_agents, num_agents))

        # cached random communication-dropout mask (mask_dist == -10 mode)
        self.dropout_mask = None

    def calculate_mask(self, inp):
        """Build the <bsz, N, N> communication mask from agent positions.

        Args:
            inp: (num_processes*num_agents, input_size) observations laid
                out agent-major (all rows of agent 0, then agent 1, ...).

        Returns:
            FloatTensor of shape (bsz, N, N) with 1 where communication
            is allowed and 0 where it is restricted.
        """
        pos = inp[:, self.pos_index:self.pos_index + 2]
        bsz = inp.size(0) // self.num_agents
        # mask accumulates 1 for *restricted* pairs; inverted on return
        mask = torch.full(size=(bsz, self.num_agents, self.num_agents), fill_value=0, dtype=torch.float)
        mask2 = torch.full(size=(bsz, self.num_agents, self.num_agents), fill_value=1, dtype=torch.float)

        if self.mask_dist is not None and self.mask_dist > 0:
            # restrict agent pairs farther apart than mask_dist
            for i in range(1, self.num_agents):
                shifted = torch.roll(pos, -bsz * i, 0)
                dists = torch.norm(pos - shifted, dim=1)
                restrict = dists > self.mask_dist
                for x in range(self.num_agents):
                    mask[:, x, (x + i) % self.num_agents].copy_(restrict[bsz * x:bsz * (x + 1)])

        elif self.mask_dist is not None and self.mask_dist == -10:
            # random symmetric communication dropout, resampled ~10% of calls
            # (or whenever the batch size changes)
            if self.dropout_mask is None or bsz != self.dropout_mask.shape[0] or np.random.random_sample() < 0.1:
                temp = torch.rand(mask.size()) > 0.85
                temp.diagonal(dim1=1, dim2=2).fill_(0)
                self.dropout_mask = (temp + temp.transpose(1, 2)) != 0
            mask.copy_(self.dropout_mask)

        # invert: 1 = allowed, 0 = restricted
        mask = mask2 - mask
        return mask

    def _fwd(self, inp):
        """Encode observations and run K rounds of attention message passing.

        Args:
            inp: (batch_size, obs_dim) observations ordered
                {iden, vel(2), pos(2), entities(...)}.

        Returns:
            (batch_size, h_dim) per-agent embeddings. Also caches the last
            round's attention weights in ``self.attn_mat``.
        """
        agent_inp = inp[:, :self.input_size]
        # (bsz/N, N, N) with 1 where communication allowed, 0 where restricted
        mask = self.calculate_mask(agent_inp)

        h = self.encoder(agent_inp)  # (batch_size, h_dim)
        if self.entity_mp:
            landmark_inp = inp[:, self.input_size:]  # (x, y) of landmarks relative to agents
            # (batch_size, num_entities, h_dim)
            he = self.entity_encoder(landmark_inp.contiguous().view(-1, 2)).view(-1, self.num_entities, self.h_dim)
            entity_message = self.entity_messages(h.unsqueeze(1), he).squeeze(1)  # (batch_size, h_dim)
            h = self.entity_update(torch.cat((h, entity_message), 1))  # (batch_size, h_dim)

        h = h.view(self.num_agents, -1, self.h_dim).transpose(0, 1)  # (bsz/N, N, h_dim)

        # K rounds of graph-attention message passing with residual-style update
        for att in self.attentions:
            m, attn = att(h, mask)
            h = self.update(torch.cat((h, m), 2))

        h = h.transpose(0, 1).contiguous().view(-1, self.h_dim)

        # keep the final round's attention weights for inspection
        self.attn_mat = attn.squeeze().detach().cpu().numpy()
        return h  # (batch_size, h_dim)

    def forward(self, inp, state, mask=None):
        raise NotImplementedError

    def _value(self, x):
        """Plain MLP value head (legacy; see attention_value)."""
        return self.value_head(x)

    def _policy(self, x):
        """Policy trunk; output feeds the categorical distribution."""
        return self.policy_head(x)

    def act(self, inp, state, mask=None, deterministic=False):
        """Sample (or argmax) an action; returns (value, action, log_prob, state)."""
        x = self._fwd(inp)
        value = self.attention_value(x)
        dist = self.dist(self._policy(x))
        if deterministic:
            action = dist.mode()
        else:
            action = dist.sample()
        action_log_probs = dist.log_probs(action).view(-1, 1)
        return value, action, action_log_probs, state

    def evaluate_actions(self, inp, state, mask, action):
        """Return (value, log_prob(action), mean entropy, state) for PPO updates."""
        x = self._fwd(inp)
        value = self.attention_value(x)
        dist = self.dist(self._policy(x))
        action_log_probs = dist.log_probs(action)
        dist_entropy = dist.entropy().mean()
        return value, action_log_probs, dist_entropy, state

    def get_value(self, inp, state, mask):
        """Value-only forward pass (used for return bootstrapping)."""
        x = self._fwd(inp)
        # FIX: use the same attention-based value head as act() and
        # evaluate_actions(); previously this called self._value, which
        # would bootstrap returns from a different, untrained head.
        value = self.attention_value(x)
        return value

class Attention_value(nn.Module):
    """Ensemble value head: averages five independent MLP value estimates.

    Six identical heads are created (``value_network_base1`` ..
    ``value_network_base6``) along with a scalar attention mixer
    (``messages_3``); ``forward`` currently uses only heads 1-5, but the
    extra modules stay registered so parameter counts and checkpoint
    keys are unchanged.
    """

    def __init__(self, hidden_dim=128, embed_dim=None, nonlin=nn.ReLU):
        super().__init__()

        self.h_dim = hidden_dim
        self.nonlin = nonlin
        self.embed_dim = self.h_dim if embed_dim is None else embed_dim

        # six identical two-layer MLP heads mapping h_dim -> scalar value
        for idx in range(1, 7):
            head = nn.Sequential(nn.Linear(self.h_dim, self.h_dim),
                                 self.nonlin(inplace=True),
                                 nn.Linear(self.h_dim, 1))
            setattr(self, 'value_network_base{}'.format(idx), head)

        # scalar attention mixer (reserved for an alternative combination scheme)
        self.messages_3 = MultiHeadAttention(n_heads=1, input_dim=1, embed_dim=1)

    def forward(self, inp):
        """Return the element-wise mean of the first five head outputs.

        Args:
            inp: (batch, h_dim) embeddings.

        Returns:
            (batch, 1) averaged value estimates.
        """
        estimates = [getattr(self, 'value_network_base{}'.format(idx))(inp)
                     for idx in range(1, 6)]
        total = estimates[0]
        for estimate in estimates[1:]:
            total = total + estimate
        return total / 5


class MultiHeadAttention(nn.Module):
    # adapted from https://github.com/wouterkool/attention-tsp/blob/master/graph_encoder.py
    """Multi-head scaled dot-product attention over graph node features."""

    def __init__(
            self,
            n_heads,
            input_dim,
            embed_dim=None,
            val_dim=None,
            key_dim=None
    ):
        """
        :param n_heads: number of attention heads
        :param input_dim: feature dimension of the query/key/value inputs
        :param embed_dim: output dimension (required if val_dim is None)
        :param val_dim: per-head value dimension (defaults to embed_dim // n_heads)
        :param key_dim: per-head key/query dimension (defaults to val_dim)
        """
        super(MultiHeadAttention, self).__init__()

        if val_dim is None:
            assert embed_dim is not None, "Provide either embed_dim or val_dim"
            val_dim = embed_dim // n_heads
        if key_dim is None:
            key_dim = val_dim

        self.n_heads = n_heads
        self.input_dim = input_dim
        self.embed_dim = embed_dim
        self.val_dim = val_dim
        self.key_dim = key_dim

        self.norm_factor = 1 / math.sqrt(key_dim)  # 1/sqrt(d_k), see "Attention Is All You Need"

        self.W_query = nn.Parameter(torch.Tensor(n_heads, input_dim, key_dim))
        self.W_key = nn.Parameter(torch.Tensor(n_heads, input_dim, key_dim))
        self.W_val = nn.Parameter(torch.Tensor(n_heads, input_dim, val_dim))

        if embed_dim is not None:
            # FIX: was (n_heads, key_dim, embed_dim). W_out projects the
            # concatenated per-head *values* (n_heads * val_dim), so its
            # middle dimension must be val_dim; the old shape crashed
            # whenever val_dim != key_dim was passed explicitly. Shape is
            # identical in the default case (key_dim == val_dim), so
            # existing checkpoints still load.
            self.W_out = nn.Parameter(torch.Tensor(n_heads, val_dim, embed_dim))

        self.init_parameters()

    def init_parameters(self):
        """Uniform init in [-1/sqrt(fan), 1/sqrt(fan)] per parameter."""
        for param in self.parameters():
            stdv = 1. / math.sqrt(param.size(-1))
            param.data.uniform_(-stdv, stdv)

    def forward(self, q, h=None, mask=None, return_attn=False):
        """
        :param q: queries (batch_size, n_query, input_dim)
        :param h: data (batch_size, graph_size, input_dim); defaults to q (self-attention)
        :param mask: mask (batch_size, n_query, graph_size) or viewable as that
            (i.e. can be 2 dim if n_query == 1). Nonzero/True means attention
            is NOT possible (i.e. mask is a negative adjacency).
        :param return_attn: also return the attention weights
        :return: (batch_size, n_query, embed_dim), optionally with attn
            of shape (n_heads, batch_size, n_query, graph_size)
        """
        if h is None:
            h = q  # compute self-attention

        # h should be (batch_size, graph_size, input_dim)
        batch_size, graph_size, input_dim = h.size()
        n_query = q.size(1)
        assert q.size(0) == batch_size
        assert q.size(2) == input_dim
        assert input_dim == self.input_dim, "Wrong embedding dimension of input"

        hflat = h.contiguous().view(-1, input_dim)
        qflat = q.contiguous().view(-1, input_dim)

        # last dimension can be different for keys and values
        shp = (self.n_heads, batch_size, graph_size, -1)
        shp_q = (self.n_heads, batch_size, n_query, -1)

        # Calculate queries (n_heads, batch_size, n_query, key_dim)
        Q = torch.matmul(qflat, self.W_query).view(shp_q)
        # Calculate keys and values (n_heads, batch_size, graph_size, key/val_dim)
        K = torch.matmul(hflat, self.W_key).view(shp)
        V = torch.matmul(hflat, self.W_val).view(shp)

        # Calculate compatibility (n_heads, batch_size, n_query, graph_size)
        compatibility = self.norm_factor * torch.matmul(Q, K.transpose(2, 3))
        # Optionally apply mask to prevent attention
        if mask is not None:
            # cast to bool so float/int masks are treated as boolean masks
            # rather than integer indices
            mask = mask.view(1, batch_size, n_query, graph_size).expand_as(compatibility).to(torch.bool)
            compatibility[mask] = -math.inf

        attn = F.softmax(compatibility, dim=-1)

        # If there are nodes with no neighbours then softmax returns nan so we fix them to 0
        if mask is not None:
            attnc = attn.clone()
            attnc[mask] = 0
            attn = attnc

        # (n_heads, batch_size, n_query, val_dim)
        heads = torch.matmul(attn, V)

        # concatenate heads and project to embed_dim
        out = torch.mm(
            heads.permute(1, 2, 0, 3).contiguous().view(-1, self.n_heads * self.val_dim),
            self.W_out.view(-1, self.embed_dim)
        ).view(batch_size, n_query, self.embed_dim)

        if return_attn:
            return out, attn
        return out
